{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example of using Triton Server Wrapper in a Jupyter notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Triton server setup with custom model"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Install dependencies"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Install with the interpreter of the running kernel (sys.executable).\n",
    "# PyTriton is imported by the cells below, so install it together with NumPy.\n",
    "!{sys.executable} -m pip install numpy nvidia-pytriton\n",
    "# NOTE(review): CuPy is not imported anywhere in this notebook - confirm this install is still required.\n",
    "!{sys.executable} -m pip install cupy-cuda12x --extra-index-url=https://pypi.ngc.nvidia.com"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Required imports:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from pytriton.decorators import batch\n",
    "from pytriton.model_config import ModelConfig, Tensor\n",
    "from pytriton.triton import Triton"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define inference callable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@batch\n",
    "def _add_sub(**inputs):\n",
    "    \"\"\"Return the sum and difference of the two input batches under keys ``add`` and ``sub``.\"\"\"\n",
    "    # Exactly two inputs are expected; they are taken in keyword order.\n",
    "    lhs, rhs = inputs.values()\n",
    "    return {\"add\": lhs + rhs, \"sub\": lhs - rhs}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instantiate the Triton wrapper class and load the model with the defined callable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "triton = Triton()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both inputs and both outputs use the same float32, shape=(-1,) signature;\n",
    "# only the outputs carry explicit names.\n",
    "triton.bind(\n",
    "    model_name=\"AddSub\",\n",
    "    infer_func=_add_sub,\n",
    "    inputs=[Tensor(dtype=np.float32, shape=(-1,)) for _ in range(2)],\n",
    "    outputs=[\n",
    "        Tensor(name=output_name, dtype=np.float32, shape=(-1,))\n",
    "        for output_name in (\"add\", \"sub\")\n",
    "    ],\n",
    "    config=ModelConfig(max_batch_size=128),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run triton server with defined model inference callable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "triton.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example inference performed with ModelClient calling the Triton server"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytriton.client import ModelClient\n",
    "\n",
    "# Two separate float32 arrays of ones, each shaped (batch_size, 1).\n",
    "batch_size = 2\n",
    "a_batch, b_batch = (np.ones((batch_size, 1), dtype=np.float32) for _ in range(2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The client is used as a context manager for this single request.\n",
    "with ModelClient(\"localhost\", \"AddSub\") as add_sub_client:\n",
    "    result_batch = add_sub_client.infer_batch(a_batch, b_batch)\n",
    "\n",
    "# Print each returned output by name, converted to a plain Python list.\n",
    "for output_name, data_batch in result_batch.items():\n",
    "    print(f\"{output_name}: {data_batch.tolist()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Re-setup triton server with modified inference callable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stop triton server"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "triton.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Redefine inference callable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@batch\n",
    "def _add_sub(**inputs):\n",
    "    \"\"\"Return the doubled sum and tripled difference of the two input batches under keys ``add`` and ``sub``.\"\"\"\n",
    "    # Exactly two inputs are expected; they are taken in keyword order.\n",
    "    lhs, rhs = inputs.values()\n",
    "    return {\"add\": (lhs + rhs) * 2, \"sub\": (lhs - rhs) * 3}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load model again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same model signature as before; only the bound callable has changed.\n",
    "triton.bind(\n",
    "    model_name=\"AddSub\",\n",
    "    infer_func=_add_sub,\n",
    "    inputs=[Tensor(dtype=np.float32, shape=(-1,)) for _ in range(2)],\n",
    "    outputs=[\n",
    "        Tensor(name=output_name, dtype=np.float32, shape=(-1,))\n",
    "        for output_name in (\"add\", \"sub\")\n",
    "    ],\n",
    "    config=ModelConfig(max_batch_size=128),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run triton server with new model inference callable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "triton.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The same inference performed with modified inference callable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same request as before, now served by the redefined callable.\n",
    "with ModelClient(\"localhost\", \"AddSub\") as add_sub_client:\n",
    "    result_batch = add_sub_client.infer_batch(a_batch, b_batch)\n",
    "\n",
    "# Print each returned output by name, converted to a plain Python list.\n",
    "for output_name, data_batch in result_batch.items():\n",
    "    print(f\"{output_name}: {data_batch.tolist()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stop server at the end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "triton.stop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
